{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Variational Autoencoder with a Normalizing Flow prior\n",
    "\n",
    "Using a normalizing flow as prior for the latent variables instead of the typical standard Gaussian is an easy way to make a variational autoencoder (VAE) more expressive. This notebook demonstrates how to implement a VAE with a normalizing flow as prior for the MNIST dataset. We strongly recommend to read [Pyro's VAE tutorial](vae.ipynb) first.\n",
    "\n",
    "> In this notebook we use [Zuko](https://zuko.readthedocs.io/) to implement normalizing flows, but similar results can be obtained with other PyTorch-based flow libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyro\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.utils.data as data\n",
    "import zuko\n",
    "\n",
    "from pyro.contrib.zuko import ZukoToPyro\n",
    "from pyro.optim import Adam\n",
    "from pyro.infer import SVI, Trace_ELBO\n",
    "from torch import Tensor\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision.transforms.functional import to_tensor, to_pil_image\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
    "The [MNIST](https://wikipedia.org/wiki/MNIST_database) dataset consists of 28 x 28 grayscale images representing handwritten digits (0 to 9)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainset = MNIST(root='', download=True, train=True, transform=to_tensor)\n",
    "trainloader = data.DataLoader(trainset, batch_size=256, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAcAAAAAcCAAAAADTxTBPAAAKc0lEQVR4nO1aa1iVVRZejiJ5Q0IUqBAbDc1LYWZGmeRlaszSzPCS2sPjNGqlZY6XYNI0zVs0ojGUUNKTkWJ2AU2rCVJT8zKEt1S8oCKJoHFTRDnvWmd+HBDO+dY+xDNdZp54f328a6+9F9/7fd9ea+1DVI961KMev0sEZ5/+rUOox3+BNwrw6a+6YI9ETrzjV13RgGX2A0G/3Oxp6S5E5xnHPpgxo/HPvIzfdsbeVj/vnA19fHxmL/rkhg/s5S9brSGFAH5059//XEeD5SW2h9U5nBYBT0V5Kny7HxkPGnyCuzxtZ2bmj5Ub7hG2vfZVl5avcCYmlIqISL/aXSvR/C+j/5FweWNCQkLCvDtNg4JTwdP7qqYGa0pv+smLXUPbDk/Gr2Vm5tPruPSb+y0D7jrDKMpHaI0b02eo85AXNxgEjLhk4z51DOjmN/cCWK5Ymn1iErDLa6dywAAArPSymH3lrH9tyy4qLx3uzPicExGRwgd+UthEtESqgf1RN6uDQgEepfs3zZWnfupa19C9kCthG/vYY3dbhGja+xQYux8HR1WTs1c5jfnDiv2d9NlftqkC9ordz/zCiNheVlOnty6BTx3AOW3GGJOAqQCqBMS9FrOvSIgeYDU241+u1MRLckpEXnfrF9RpbkFBIhERHReR82lpaWlxaRkiMkgbHpzNPMQ0V7pEmUxERH+LfI/5exfS5xgzM+/YWF6iOq1y3JmINF5dTR53FvBGeU9fccAF24Gg6yz0iHPgtP0Ar3G1tHyrCMCRoFuA3tbpvL82CTgZODv/1fnz00wC2kP0CIn6fOFDRDTqQtbtFlum7BeRP5pciWhAXJEw82EiImr/UPv2AQ6+xSmRFZrDPKy/0TjbMNONJKKwZ9faAKDikIvh0befZc5oRl3iNbcehczpUzn39iH2Gnc721nATTJLXbN3rs32pIVtFFqK9L4ezTeCp7naIgAgK5B0AW84wXhJzWIaBQb6ExF55QDrrNunr0ioGiERHeHeREQH7UOttse/ExG51eRKb+8SkZK4cdZn9AmR8p6Kx47Lx28xTkeBciVA4wM25+SUgHcDACwViFeDeH7CMGNIIbC++aDI1kR88VoieluZs4A75G7VO4E5zcpGAJu8iMYAp1u72j4Djq9uS/SIKiDNYmCSIVQiIgq/CMRYaV8Ro9t3GEBEIaXqxuS/T0Q+NHi2ipcLe4Z1bGu1NH7rskh3xWUIY6GbPCXQLhMUesBJAEDHVh37ngI2WQe8xul/UCcMTuL8vY87rhlJVfSL4iSgX54Eat6+bDtvTeHmM5Z7EdFhwLoX3DDnnjZERE/pAlItAo5MA2DNYci7SJYaXObZDrYmarYa2z2sxtGLWUSmGFxjeFlz1dBvpcjVCdbXkryjGTOIiJ6PjtYcA0UmK/SXAMom9SSiOOC4r3VAs3RWMy3PVBQ/2KryiWF8U8Unyos1h62SI96Kd7sMts22sLO5/NMmRNcNLuO52poOvGMQ0M5sFnD0wXIA/26imFJNAgaeKw8johXIsdo6HaoQMe2BTedmDx6iaEREd9lE5MpDygPR4ms730009YVsZruyFeoCPlAKZDs29lRATanal5x+d1IDCx0KVBdxTgL2v0Z7DU+9IqO1SSfa+POWrqT3OXxKRB12AcnNNC+i5yKjIndja0PNZnoD283asmULAygcr36hTAJ2O4alRDTtKpRP19ByR0WgVTS0gFfr8hFFO9x2z+rmankYfPIWCvkYKD3E31p3c13AL4Ct/YmIrn+iuPLKGmsx80zL/rmDa/Qm7Lyt6jJRwono9u7TYuJKLhasL4GW8j9abNviZ2HbAG3bzNxewrA9ogXStOd6Zjtzbns1ToOA3bKryogU1Y1S1fSuUQTbeWeUZ8CuipWa13OXRUx7oF0G60sR3fNZQWUtGN3GydBi
MnJfoeAkzn+/exgfUgS0awIOy0h3VLGRwD5TPdvtS+Y4l5f64cuYUv0XI7bqMo5/zMzMZKko3rF09E0e+RXKhO2YOdFKe+eBAeScQZ7i5HHXGVzMXVsK5E1XO1gmAU8yM9uZmR/S3ChVihV2DMBZwE49GCIaOGrU2GJdwF2c8yfdiYja3jEwgUVEvnbKLQYCs8kvFcWxnl0PF8da/Qx7YCUeKceVZ4xW77Fgl0o2HGevvZSeC/nL6h17ZkpKSkrKOEfuOV6OK/O9abPZtPZMr/OctaSz/2Yo37TGg4FZ95LPXgAYoTXT7MxrteiD/n5n165duy4FdAFf0AQcYSvP6xuSBjBsZ/Q3nqjBHDlmeVV6NSafOVxiri+IaPROEZEZNamZANF2IIxCAS2LCRRx03hkYLy7Fa/y1fudiHCcrLr0nIfTpi5ksiy2kiEnbLZ15rX62Nn6qHksBDZ4U+s9XD73I+Dzft0tmTgD6GyetqVJwGFSZv1gpZ8YR0Sdt4EBYwHtKXLIZVsNyDg/hsiX+R5zJETUaLOIOBXXC/ljCvmBp1BwNk/RXAJFzEXGAjuzqeVMdNsrm5gznYuJcCyrvApJwkdG12StSC6w2bbpOTYRET3IsNSADReh5JnrqedOHOlLXn9eVYLqB6gK/4Ra51VhuEnAIXI52EI+H0hE1KcIwzt3VmoPB5aIuLYb8somEdF8/sLo5MDr4uK7EB9RyBmszCnYdaN6c9wJ2HgT41m92iPqGPsDM1dsdGaHc2XRP7WQzS0eXUC22Qwd20q7VcCnUTrSZ+DaS5jtqCpHbdhgaVpM1gT0GFRZOYwrNQlIhyRON7SM5aMK3SrF0dsIKLaWEZFlIpIlJ/UTtoDZlT3xhl+JVNxX0xQKhE4sBjjfEGWgiOlT3nQ8eJW3bvOfeoKZeZdrWhWOq8tDAsNTT/PJ1XqrhYiIku3WdlmindndyZ32BuahLOMIgJfUAqISR5ntLv/jfZsQSETkM6YIuKgf1FBMiSHpj0Se9tC/L4fDOlCPUd+JvGbxnJaUn1+woaMap/8+cdROfotFZJ+TrUepo+Gu5S9E5C6JabEGeE5///z6fc/MvGOoxRwO4IfDALa9YliQiIiSJcKVCsnl8mhTmURE9IwiYCYApE7r0MjdYp8A7CLgXuCNBQsWLNjDwFfDDH4xxfp/H3TCpvYTQreLZG8oEeHvDeWqAWtEQpoQNXm5RMRe6pKSDEpj4J0pxkSl8UGTgLcCWarB58NjzMzfPKr0L276FmAgf5nVVBPJ1rb7/TbWMtNqdLOzRcAWY5dG+tV2/D1QFdABPrvC+NDEyGMqfxTv6g7RTztquQu1BOSKv4pIRnp6hohIqV50u8EeSVX5TglQykaiXutymJkvvao/ZQFzwHjdTfeciIiS7XUXkI7CzUfZDYIOWATs/g4AICtzuaXzUY2z5fqJZRR0YYk8p09PEimq609Jbv6g6ki3Yoly3FkLEsT1Zx0OJAFqBbiImQ8unO9d54VqIsL6BvpvqU3ACKS5KQfqBs8J57Fugvsj9zX7fsEf0zhHMzJ+2tat8fEjtdOI2tDu24ka3WU94swVxG8Cr8+xtm6by+8Zi3Hif0w/Iq833FXk9XBCf+XkrR7/5/gP6yvFvKvDVPUAAAAASUVORK5CYII=",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=448x28>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = [trainset[i][0] for i in range(16)]\n",
    "x = torch.cat(x, dim=-1)\n",
    "\n",
    "to_pil_image(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "As for the [previous tutorial](vae.ipynb), we choose a (diagonal) Gaussian model as encoder $q_\\psi(z | x)$ and a Bernoulli model as decoder $p_\\phi(x | z)$. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GaussianEncoder(nn.Module):\n",
    "    def __init__(self, features: int, latent: int):\n",
    "        super().__init__()\n",
    "\n",
    "        self.hyper = nn.Sequential(\n",
    "            nn.Linear(features, 1024),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(1024, 1024),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(1024, 2 * latent),\n",
    "        )\n",
    "\n",
    "    def forward(self, x: Tensor):\n",
    "        phi = self.hyper(x)\n",
    "        mu, log_sigma = phi.chunk(2, dim=-1)\n",
    "\n",
    "        return pyro.distributions.Normal(mu, log_sigma.exp()).to_event(1)\n",
    "\n",
    "\n",
    "class BernoulliDecoder(nn.Module):\n",
    "    def __init__(self, features: int, latent: int):\n",
    "        super().__init__()\n",
    "\n",
    "        self.hyper = nn.Sequential(\n",
    "            nn.Linear(latent, 1024),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(1024, 1024),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(1024, features),\n",
    "        )\n",
    "\n",
    "    def forward(self, z: Tensor):\n",
    "        phi = self.hyper(z)\n",
    "        rho = torch.sigmoid(phi)\n",
    "\n",
    "        return pyro.distributions.Bernoulli(rho).to_event(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, we choose a [masked autoregressive flow](https://arxiv.org/abs/1705.07057) (MAF) as prior $p_\\phi(z)$ instead of the typical standard Gaussian $\\mathcal{N}(0, I)$. Instead of implementing the MAF ourselves, we borrow it from the [Zuko](https://github.com/probabilists/zuko) library. Because Zuko distributions are very similar to Pyro distributions, a thin wrapper (`ZukoToPyro`) is sufficient to make Zuko and Pyro 100% compatible."
   ]
  },
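  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check (a small sketch, not part of the original pipeline), we can wrap an untrained flow with `ZukoToPyro` and verify that it behaves like a Pyro distribution, exposing `rsample` and `log_prob`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: ZukoToPyro turns a Zuko flow into a Pyro-compatible distribution.\n",
    "flow = zuko.flows.MAF(features=16, transforms=3, hidden_features=(256, 256))\n",
    "dist = ZukoToPyro(flow())\n",
    "\n",
    "z = dist.rsample((4,))  # reparameterized samples of shape (4, 16)\n",
    "\n",
    "print(z.shape, dist.log_prob(z).shape)"
   ]
  },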
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "VAE(\n",
       "  (encoder): GaussianEncoder(\n",
       "    (hyper): Sequential(\n",
       "      (0): Linear(in_features=784, out_features=1024, bias=True)\n",
       "      (1): ReLU()\n",
       "      (2): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "      (3): ReLU()\n",
       "      (4): Linear(in_features=1024, out_features=32, bias=True)\n",
       "    )\n",
       "  )\n",
       "  (decoder): BernoulliDecoder(\n",
       "    (hyper): Sequential(\n",
       "      (0): Linear(in_features=16, out_features=1024, bias=True)\n",
       "      (1): ReLU()\n",
       "      (2): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "      (3): ReLU()\n",
       "      (4): Linear(in_features=1024, out_features=784, bias=True)\n",
       "    )\n",
       "  )\n",
       "  (prior): MAF(\n",
       "    (transform): LazyComposedTransform(\n",
       "      (0): MaskedAutoregressiveTransform(\n",
       "        (base): MonotonicAffineTransform()\n",
       "        (order): [0, 1, 2, 3, 4, ..., 11, 12, 13, 14, 15]\n",
       "        (hyper): MaskedMLP(\n",
       "          (0): MaskedLinear(in_features=16, out_features=256, bias=True)\n",
       "          (1): ReLU()\n",
       "          (2): MaskedLinear(in_features=256, out_features=256, bias=True)\n",
       "          (3): ReLU()\n",
       "          (4): MaskedLinear(in_features=256, out_features=32, bias=True)\n",
       "        )\n",
       "      )\n",
       "      (1): MaskedAutoregressiveTransform(\n",
       "        (base): MonotonicAffineTransform()\n",
       "        (order): [15, 14, 13, 12, 11, ..., 4, 3, 2, 1, 0]\n",
       "        (hyper): MaskedMLP(\n",
       "          (0): MaskedLinear(in_features=16, out_features=256, bias=True)\n",
       "          (1): ReLU()\n",
       "          (2): MaskedLinear(in_features=256, out_features=256, bias=True)\n",
       "          (3): ReLU()\n",
       "          (4): MaskedLinear(in_features=256, out_features=32, bias=True)\n",
       "        )\n",
       "      )\n",
       "      (2): MaskedAutoregressiveTransform(\n",
       "        (base): MonotonicAffineTransform()\n",
       "        (order): [0, 1, 2, 3, 4, ..., 11, 12, 13, 14, 15]\n",
       "        (hyper): MaskedMLP(\n",
       "          (0): MaskedLinear(in_features=16, out_features=256, bias=True)\n",
       "          (1): ReLU()\n",
       "          (2): MaskedLinear(in_features=256, out_features=256, bias=True)\n",
       "          (3): ReLU()\n",
       "          (4): MaskedLinear(in_features=256, out_features=32, bias=True)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (base): Unconditional(DiagNormal(loc: torch.Size([16]), scale: torch.Size([16])))\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class VAE(nn.Module):\n",
    "    def __init__(self, features: int, latent: int = 16):\n",
    "        super().__init__()\n",
    "\n",
    "        self.encoder = GaussianEncoder(features, latent)\n",
    "        self.decoder = BernoulliDecoder(features, latent)\n",
    "\n",
    "        self.prior = zuko.flows.MAF(\n",
    "            features=latent,\n",
    "            transforms=3,\n",
    "            hidden_features=(256, 256),\n",
    "        )\n",
    "\n",
    "    def model(self, x: Tensor):\n",
    "        pyro.module(\"prior\", self.prior)\n",
    "        pyro.module(\"decoder\", self.decoder)\n",
    "\n",
    "        with pyro.plate(\"batch\", len(x)):\n",
    "            z = pyro.sample(\"z\", ZukoToPyro(self.prior()))\n",
    "            x = pyro.sample(\"x\", self.decoder(z), obs=x)\n",
    "\n",
    "    def guide(self, x: Tensor):\n",
    "        pyro.module(\"encoder\", self.encoder)\n",
    "\n",
    "        with pyro.plate(\"batch\", len(x)):\n",
    "            z = pyro.sample(\"z\", self.encoder(x))\n",
    "\n",
    "vae = VAE(784, 16).cuda()\n",
    "vae"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "We train our VAE with a standard stochastic variational inference (SVI) pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 96/96 [24:04<00:00, 15.05s/it, loss=63.1]\n"
     ]
    }
   ],
   "source": [
    "pyro.clear_param_store()\n",
    "\n",
    "svi = SVI(vae.model, vae.guide, Adam({'lr': 1e-3}), loss=Trace_ELBO())\n",
    "\n",
    "for epoch in (bar := tqdm(range(96))):\n",
    "    losses = []\n",
    "\n",
    "    for x, _ in trainloader:\n",
    "        x = x.round().flatten(-3).cuda()\n",
    "\n",
    "        losses.append(svi.step(x))\n",
    "\n",
    "    losses = torch.tensor(losses)\n",
    "\n",
    "    bar.set_postfix(loss=losses.sum().item() / len(trainset))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After training, we can generate MNIST images by sampling latent variables from the prior and decoding them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAcAAAAAcCAAAAADTxTBPAAAMaklEQVR4nO1ae3hV1ZVf55z7SCAhDwKEAIIRQ8RHR6ihWgWUdhSwCgq1BbGOo1Xw4xtnPorV8YG11seonfqAUYYCVUdFOx8WoYjKJ2g0akB5SgmPJBDyvrnve87Ze/2YP5KQe87Z91paav8Y1j/3nvXba+2199p77bUfRKfpNJ2m05SBCoN/bwtO019B+h/5kP5NVnhtZLNxilXqI0+xQiLtlGv8W1EH0OX/ButbyODJp06dQUTF7bNOWq5w3F3/MEztpYlftElAHMn9K007tWSMmr912dXehr4EJC9RCZy7N9zW/mCxWpuu+/WAm5lohbR3fj+7HXMAuXeghx3I7z+keEzFkPPzMw/9vBA6LnUzNSIKdhxQlQ+W9ss4MrXVcYa57kwPoA8IoYf4IhfW71erBj7SkpAAgM7verUWV9si+YOMDeilyvyTCUGGz+hfNPbCdSFGKlLqRmcC4gqFVP9dlm0fqW1cq+rPu7oSXc3Ny2cNSGdq/w2AJYf35mSxZqiEeNbjPy0w4dlXtreGU6aILsnUuEE2IC9wMQNBIqKGVm9x/dpdVvUjFf3VyqYmAEC03TXAyTeqGBChQ/Ov+61AlxPLezxqs2QAYIthegy9DABgnZuhAaQHi0b/+AOTwTOUw9SXN+6Z6g2OQee77IO6Q0dTlmUKMFKfudOVwQBqvDmM9pCA6Fr/8yd2fuWdguP+BLu189Duxsj/OOoa1BRas2ji4vodnlGSRnHIGbnuJVcreb4+KYVkBhAeppb0NwOIFbq4Q0uJiBbt8Za/OMwQ8VgoWr/cEyuILksxW7YQbD/l8MOgsOSa6ToRGQ0QLkN/+u15c0t6ur7UxtVOVI8BEBbbT2cIIvp7+ztZWAz5aUU6v2hwTmnx7xraBXeFZczcmiZRHBZSCgGwYMkiOdytkoGkN8z4qqVouu++f5v/Qn3HT90iH4IjQ/sXDrt8xZfbnD0T8F2gG4HJq1QzuocuBV71ZkzBxe1SCMl2PCYhf6hsv/YmgMQ/u9nl9xER3fekp7w+34SdSJq2ZHkgzwPPaZfmJ9s7JSDS/aCttWVNQff/hznWy/Z1/7xItSdKtvnoU2eFyQAll6zyzV0wYe7LqhbQ3f7wamN13DTZwR7zbMXHdRNHBrTNjWusmzaVLeuDxm7zE3H8OH3YcLW/sDXwO1e8oJhGx4cIT01rL0oeuU6MDI64oNhY8pYzPL3xXR7eSkThlpGzNzhFbU1UiVE/GFunNJ+IiLaSXAAP1ze3UNNi5uHXNvLr36I5bx5XSFZeR4TyNjf7zmIiovmBRW4AywZdszOZM6h+8lC79l8fE8ZxR70RXbOinbpZYFQPLz98osLjB/mzaXEiIsq9U9uduSE0Hwkn4zDQ6ieioprwDuUqUFC9MqAYm9oys33PL64qNbxCvloG+Mjqy4tzvvXae7t/WTXEJX8rgLO8KoPCqvEHArNuvGNTS0Njc2E6NhLWpO5/A1tS17ttmXXbuo5oW82g3AxB5CvgTgV7XruUbS/dPSKoGdMTiKiWLd0E8D0vf/8WIqKorRDRrqkYUFpSMuvFf7l/00fv7l//j8/0Yb4QYO6/6/p/KvR1J7K91O/24T2fUwFvitNHwMOO72E24gYRFW1jyEKFgD/cpEyp8mqbbrmjcrAK0m+XgGzccGlw0KLY1jeLfS48yMAmhdwtfKSYjAHfrhh/84IZ20U0PVDGeGn3n3NMCI+d+qpEdL9lSRa1boiIyAc0q1w7KwXzqaEGEVFBE1KK9I5qAOzzsg2zkogomvBCRDln6Vqg7NUdW3fZEraRXvMNAvzZYL9i/6v3+E+3IFVKe2gLXIFkNrCnNP+JBgZgqfSGozcrh/Xlu9cc/WDjvWNUoP/K95NSpFo2hGz+osoDHwNYNdlb+JdEmtEvZ7A/J+e5JG/vgwIm30xEVNzI
QLNXNGfRbSV5+ctijBcVim0gX8HWX2OEL+1ugLGX5QPeIiMAJBVNPAPDiYjYUqil4eeOzrskKlOJjgS/4By9OwFrhj/LZl17CJiaGc4BUk7OVMYJirmnChG1Au0rliq2lmflvx8RkpnvUZmhl698u5UZwHbPqMgBcC9R8XiX8/Wj5kgiI1CokUakr5fhPi9rM6RsMYWU4XqTdygq1Lp1ze4Q3lx0OrDFMPp5ui03Dv6kpKfy9RJrvB3LAPopGriU84iIOKLAKPjU4RoAVvXMSW6N/8Xg0MIsR4hXMjZmRgnuCUhBGwAQf/stE5jkEThTQNg2+LEMCiuPyPSJkk7GYhOArPAA9wOdlcNbpEzc6GifPyUKiPScHo+PYKQFaH0fAJgrA/qvPW1Ip4VW3MNrgrXwwrhV7V4KpjHERb3j63PmlZ6xthHANlU1T0f7ERFFL1eBPosBvK9yvP+PEuB6xUzppuEWGrPMzwjg2V757z/6pxG6RtpDwH96JAZcdXFAN0ZNjG/MpFYbKqYok5/cK/a3Slje3VUH8EYtA+BEUTo/x7TTw+0Q5iHpcGBmsUZE9KOsDvR7V5CzgE+Xx8HcVOAEbgbCvZt//TB4hbuN5wJQpSlU0B5fWqFR2VpltLsDQDLDUZJ/ZYSBLzNYnxNCNKNziZYDxzKjdcDvHYz01XdJSrWIdFP/RycpFs95CVG9KQS+2oN0ADEApoTl2LnpNqenEdPAqjFMNVkdSPDsTjaAX65LiA7JNzmBLcC2Xsvzo5DTXII5DEB5vHPDRxJstgt5THE4pTEQUZpORES+SRawXo2tA49WI0bZwPxXAWQ5PY4Ajv1H3uQ0t8wW6YeBRm5BussKny93KyvYx2g5r7wLuNINaQKQ4H3vJCBLHDZajlx/D44o7TwGZAkyOtx5hZbittt/vaJsVAN/5kQOAtW9/f9jhu32VRTAF8pazngr3Pz8D48CexUJyV6g/fYs84jOY+BuFfAusKTvyzErtIKyJQw8mkVtCnAsH/6Ge7W+/67c4OPpaR+5q19z6coLAS3l/lIJUeKCSJNgNo9+kmRZ7URMpO+XQ6w4aCSiVmCIEiAiosVocXHG2PybqpIA5dfzLifSArzX40B/F/hXLsE5GfK67lZ0K7jKi5QAqXsm5RVkvirT32PEz/fynwN+1Kdfd8ztXG3QjSYr43kPBRk46DDxddG5Pp+IKHeeyZ87S/97ZEpBj3snHzz8vCuOGF8wrHH+nEeVe6gIYJsMIORq4yOw+9LdcXZIbWgSGOfl9o617Wh3IQ+wfKBEI/1njLOdyJfAge5Zp38ERIucaNAEpIvnotWpMg8v0MWJh+5+adHozKmmMV4Ac9xTV1vKXBMg0sioevLA76dWudLp/Djbt2Yx5mwAzhEarLMFR7pkEvjKfWI08OkPt5lf7tpdz8zY7o495SbkvQEaawGKA+YnenYtne4xmpviE2uD0S7nqQ3tgpzr5Q7rXnQqvatEuc27p4wdewhwRVCazYicR0SkrwGk+yT1HQCb1Sb00mJFZvCClAf2toWTe24qML7jenNQdH5Q1/oHLz4gwetccv5O5rqqazY1R+qSABIvlzg9fMkbIvYMZSbtD4C9wM08e8KK/z3U/FWZatExKpcvW2+aUfMMDzSe7dDYc26KA58rBINxMMsOxeXH06bdc+zt32sfzmDpxxCz3WerRI1PakRUGFPs8l+RIiUY2OW5jNjHduu00pJHjgJiodtK62snIO2IezpGu0dyrFVKpJq+c05VyQQH9h+tB+t3dwgAEHe4BN9lWAcPnrgSDFc5D7nGHQQfCxIRZUhjygGECtTYSVOhBbsrJgHLswISEY18fPP1o1VLhB6S8jY/kb7CTCi3X0REbwBbvWNmQqzuucdqJTq8AoHaEDNHPYkW0SgbkKkkA9KzBS4D8FEmG3poe523Fec1tot9XQ3tLbfk5QYDIxwFfmKfODQ54Jq7htnrOhE+tGXVlErNsQ/RF8Uh
p2W7jj0fwOZsqdPJUF6y2xZ75kkKap8IyZIBMTZjmd8CkVu8IWGLRIbTQCIin7pptwqAAVnvPWq6BjDdt2Bu+s2DitCUO/OSvKKgoZFmaEaFo8/16dHuW3X7HffBuf9jZpZd6y5QXRwQBRNA/MZsT4iKTI7v/Bp7/3yaGGGWqQf+ggGR+2CM7bYFWTYK44Fji73svAjAh0721dKFu9pjTS+UKXrm7M6Ga7/uzdVJP0DSSqY/NnvKmcZJP+bqd+zVxBg9q9jorrXeFyN/MRmFQ3x/q/dV70dv1b7R52yn6TSdpv+f9H+57ZBcSS2v2QAAAABJRU5ErkJggg==",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=448x28>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z = vae.prior().sample((16,))\n",
    "x = vae.decoder(z).mean.reshape(-1, 28, 28)\n",
    "\n",
    "to_pil_image(x.movedim(0, 1).reshape(28, -1))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pyro",
   "language": "python",
   "name": "pyro"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
